{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qg5s5R7AT8Fj"
      },
      "source": [
        "# Multi-agent Simulation\n",
        "\n",
        "This tutorial demonstrates how to run a simple closed-loop simulation with multiple pre-defined sim agents."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MtgRcYqmtMwD"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "from jax import numpy as jnp\n",
        "import numpy as np\n",
        "import mediapy\n",
        "from tqdm import tqdm\n",
        "import dataclasses\n",
        "\n",
        "from waymax import config as _config\n",
        "from waymax import dataloader\n",
        "from waymax import datatypes\n",
        "from waymax import dynamics\n",
        "from waymax import env as _env\n",
        "from waymax import agents\n",
        "from waymax import visualization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dkJwTuSLr0gh"
      },
      "outputs": [],
      "source": [
        "# Configure the dataset:\n",
        "max_num_objects = 32\n",
        "\n",
        "config = dataclasses.replace(_config.WOD_1_0_0_VALIDATION, max_num_objects=max_num_objects)\n",
        "data_iter = dataloader.simulator_state_generator(config=config)\n",
        "scenario = next(data_iter)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DINmUYg7y-jI"
      },
      "source": [
        "## Initializing and Running the Simulator\n",
        "\n",
        "Waymax uses a Gym-like interface for running closed-loop simulation.\n",
        "\n",
        "The `env.MultiAgentEnvironment` class defines a stateless simulation interface with two key methods:\n",
        "- The `reset` method initializes and returns the first simulation state.\n",
        "- The `step` method takes a state and an action as arguments and returns the next simulation state.\n",
        "\n",
        "Crucially, the `MultiAgentEnvironment` does not hold any simulation state itself, and the `reset` and `step` functions have no side effects. This allows us to use functional transforms from JAX, such as jit compilation, to optimize the computation. It also allows the user to branch or restart the simulation from any state, or to save the simulation by serializing and saving the state object.\n",
        "\n"
      ]
    },
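    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of why a stateless interface is convenient, here is a minimal sketch in plain Python. It is not Waymax code: `ToyState`, `toy_reset`, and `toy_step` are hypothetical names invented for this example, and the 1-D dynamics are an assumption chosen only to keep it short. The point is that `toy_step` returns a new state without modifying its input, so the same state can be stepped several times to explore different branches.\n",
        "\n",
        "```python\n",
        "import dataclasses\n",
        "\n",
        "\n",
        "# Hypothetical toy state (not a Waymax type): one object moving along a line.\n",
        "@dataclasses.dataclass(frozen=True)\n",
        "class ToyState:\n",
        "  timestep: int\n",
        "  position: float\n",
        "\n",
        "\n",
        "def toy_reset(start_position):\n",
        "  # Returns the first state; nothing is stored on an environment object.\n",
        "  return ToyState(timestep=0, position=start_position)\n",
        "\n",
        "\n",
        "def toy_step(state, velocity, dt=0.1):\n",
        "  # Pure function: builds a new state and leaves the input untouched.\n",
        "  return dataclasses.replace(\n",
        "      state,\n",
        "      timestep=state.timestep + 1,\n",
        "      position=state.position + velocity * dt,\n",
        "  )\n",
        "\n",
        "\n",
        "s0 = toy_reset(start_position=0.0)\n",
        "s1 = toy_step(s0, velocity=10.0)\n",
        "\n",
        "# Branch twice from the same intermediate state with different actions.\n",
        "branch_slow = toy_step(s1, velocity=0.0)\n",
        "branch_fast = toy_step(s1, velocity=20.0)\n",
        "```\n",
        "\n",
        "Because each state is an immutable value, `s1` is unchanged after both branches, and saving a run amounts to keeping (or serializing) the state objects themselves."
      ]
    },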
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x7u2dqPCtdeq"
      },
      "outputs": [],
      "source": [
        "# Configure the multi-agent environment:\n",
        "init_steps = 11\n",
        "\n",
        "# Set the dynamics model that the environment uses.\n",
        "# Note that each actor interacting with the environment must provide actions\n",
        "# compatible with this dynamics model.\n",
        "dynamics_model = dynamics.StateDynamics()\n",
        "\n",
        "# Expect users to control all valid objects in the scene.\n",
        "env = _env.MultiAgentEnvironment(\n",
        "    dynamics_model=dynamics_model,\n",
        "    config=dataclasses.replace(\n",
        "        _config.EnvironmentConfig(),\n",
        "        max_num_objects=max_num_objects,\n",
        "        controlled_object=_config.ObjectType.VALID,\n",
        "    ),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "876iboHbYx2H"
      },
      "source": [
        "We now create a set of sim agents to run in the simulation. By default, an object that is not controlled replays the behavior stored in the dataset (log playback).\n",
        "\n",
        "For each sim agent, we define the algorithm (such as IDM) and specify which objects the agent controls via the `is_controlled_func`, which must return a boolean mask marking the controlled objects.\n",
        "\n",
        "The IDM agent we use in this example is the `IDMRoutePolicy`. It follows the spatial trajectory stored in the logs, but adjusts the speed profile based on the IDM rule, slowing down or speeding up according to the distance between the vehicle and any objects in front of it. The other agents in this example use either a constant speed policy, which follows the logged route at a fixed speed, or an expert policy, which replays the logged behavior."
      ]
    },
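    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the IDM rule mentioned above more concrete, the sketch below computes the acceleration given by the standard Intelligent Driver Model formula. This is the textbook form of the rule, not necessarily what `IDMRoutePolicy` implements internally; the function name `idm_acceleration` and all parameter values are illustrative assumptions.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def idm_acceleration(\n",
        "    speed,  # speed of the controlled vehicle, m/s\n",
        "    lead_speed,  # speed of the object in front, m/s\n",
        "    gap,  # distance to the object in front, m\n",
        "    desired_speed=15.0,  # m/s (illustrative value)\n",
        "    min_gap=2.0,  # m (illustrative value)\n",
        "    time_headway=1.5,  # s (illustrative value)\n",
        "    max_accel=1.0,  # m/s^2 (illustrative value)\n",
        "    comfort_decel=2.0,  # m/s^2 (illustrative value)\n",
        "    delta=4.0,  # free-road exponent\n",
        "):\n",
        "  # Textbook IDM; an assumption-laden sketch, not the Waymax implementation.\n",
        "  approach_rate = speed - lead_speed\n",
        "  desired_gap = min_gap + max(\n",
        "      0.0,\n",
        "      speed * time_headway\n",
        "      + speed * approach_rate / (2.0 * math.sqrt(max_accel * comfort_decel)),\n",
        "  )\n",
        "  free_road_term = (speed / desired_speed) ** delta\n",
        "  interaction_term = (desired_gap / gap) ** 2\n",
        "  return max_accel * (1.0 - free_road_term - interaction_term)\n",
        "\n",
        "\n",
        "# Large gap, same speed as the lead object: the vehicle keeps accelerating.\n",
        "far_lead = idm_acceleration(speed=10.0, lead_speed=10.0, gap=200.0)\n",
        "# Small gap, slower lead object: the vehicle brakes.\n",
        "close_slow_lead = idm_acceleration(speed=10.0, lead_speed=5.0, gap=10.0)\n",
        "```\n",
        "\n",
        "With these values, `far_lead` is positive and `close_slow_lead` is negative, which matches the qualitative behavior described above: the agent speeds up on a free road and slows down when it closes in on an object ahead."
      ]
    },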
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IfCHlgJzghUS"
      },
      "outputs": [],
      "source": [
        "# Set up a few actors; see the visualization below for how each actor behaves.\n",
        "\n",
        "# An actor that doesn't move, controlling all objects with index \u003e 4\n",
        "obj_idx = jnp.arange(max_num_objects)\n",
        "static_actor = agents.create_constant_speed_actor(\n",
        "    speed=0.0,\n",
        "    dynamics_model=dynamics_model,\n",
        "    is_controlled_func=lambda state: obj_idx \u003e 4,\n",
        ")\n",
        "\n",
        "# IDM actor/policy controlling objects 0 and 1.\n",
        "# Note that the IDM policy is an actor hard-coded to use dynamics.StateDynamics().\n",
        "actor_0 = agents.IDMRoutePolicy(\n",
        "    is_controlled_func=lambda state: (obj_idx == 0) | (obj_idx == 1)\n",
        ")\n",
        "\n",
        "# Constant speed actor with predefined fixed speed controlling object 2.\n",
        "actor_1 = agents.create_constant_speed_actor(\n",
        "    speed=5.0,\n",
        "    dynamics_model=dynamics_model,\n",
        "    is_controlled_func=lambda state: obj_idx == 2,\n",
        ")\n",
        "\n",
        "# Expert/log actor controlling objects 3 and 4.\n",
        "actor_2 = agents.create_expert_actor(\n",
        "    dynamics_model=dynamics_model,\n",
        "    is_controlled_func=lambda state: (obj_idx == 3) | (obj_idx == 4),\n",
        ")\n",
        "\n",
        "actors = [static_actor, actor_0, actor_1, actor_2]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JvV82vOXWbI0"
      },
      "source": [
        "We can (optionally) jit-compile the `step` and `select_action` functions to speed up computation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YmzEJy2LUwe5"
      },
      "outputs": [],
      "source": [
        "jit_step = jax.jit(env.step)\n",
        "jit_select_action_list = [jax.jit(actor.select_action) for actor in actors]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iBeI2mdIUdTw"
      },
      "source": [
        "We can now write a for loop to run all of these agents in simulation together.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "06SjvXdRrV3N"
      },
      "outputs": [],
      "source": [
        "states = [env.reset(scenario)]\n",
        "for _ in range(states[0].remaining_timesteps):\n",
        "  current_state = states[-1]\n",
        "\n",
        "  outputs = [\n",
        "      jit_select_action({}, current_state, None, None)\n",
        "      for jit_select_action in jit_select_action_list\n",
        "  ]\n",
        "  action = agents.merge_actions(outputs)\n",
        "  next_state = jit_step(current_state, action)\n",
        "\n",
        "  states.append(next_state)"
      ]
    },
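    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above combines the per-actor outputs with `agents.merge_actions`. As a mental model only, the sketch below shows one way that outputs from several actors could be combined using boolean control masks. It is an assumption about the general idea, not the actual implementation of `agents.merge_actions`: the scalar speed commands, the variable names, and the requirement that masks do not overlap are all made up for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "num_objects = 6\n",
        "obj_idx = np.arange(num_objects)\n",
        "\n",
        "# Hypothetical per-actor outputs: (commanded speed per object, control mask).\n",
        "actor_outputs = [\n",
        "    (np.full(num_objects, 0.0), obj_idx \u003e 3),  # static actor\n",
        "    (np.full(num_objects, 5.0), obj_idx == 2),  # constant speed actor\n",
        "    (np.full(num_objects, 8.0), (obj_idx == 0) | (obj_idx == 1)),\n",
        "]\n",
        "\n",
        "merged_speed = np.zeros(num_objects)\n",
        "merged_valid = np.zeros(num_objects, dtype=bool)\n",
        "for speed, mask in actor_outputs:\n",
        "  # Masks are assumed to be disjoint, so each object has at most one owner.\n",
        "  merged_speed = np.where(mask, speed, merged_speed)\n",
        "  merged_valid = merged_valid | mask\n",
        "```\n",
        "\n",
        "In this toy setup, object 3 is not claimed by any actor, so its entry in `merged_valid` stays `False`. In the tutorial's setting, such an uncontrolled object would fall back to log playback."
      ]
    },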
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s6TmqhLRGc_7"
      },
      "source": [
        "## Visualizing the Simulation\n",
        "\n",
        "We can now visualize the result of the simulation loop.\n",
        "\n",
        "On the left side:\n",
        "- Objects 5, 6, and 7 (controlled by `static_actor`) remain static.\n",
        "- Objects 3 and 4 (controlled by `actor_2`) follow log playback and collide with objects 5 and 6.\n",
        "\n",
        "On the right side:\n",
        "- Object 2 (controlled by `actor_1`) moves at a constant speed of 5 m/s, which is slower than the logged speed in this case.\n",
        "- Objects 0 and 1 (controlled by the IDM agent) follow the log at the beginning, but object 1 slows down when approaching object 2."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cYWjd4dE1bB2"
      },
      "outputs": [],
      "source": [
        "imgs = []\n",
        "for state in states:\n",
        "  imgs.append(visualization.plot_simulator_state(state, use_log_traj=False))\n",
        "mediapy.show_video(imgs, fps=10)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "multi_actors_demo.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "14w5MbrMNLsOsLuD5kXy5-rrNO3ZgsHat",
          "timestamp": 1678404744504
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
